Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API -part #59674

Merged
merged 25 commits into from
Dec 13, 2023

Conversation

YibinLiu666
Copy link
Contributor

@YibinLiu666 YibinLiu666 commented Dec 4, 2023

PR types

New features

PR changes

APIs

Description

这个PR的改动为:

  1. 按照RFC增强put_along_axis算子,支持min、max、mean规约方式。【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API community#636
  2. 修复了已有add、mul规约方式梯度计算错误的问题。
  3. 为paddle实现了底层的原子乘操作,修复了GPU上乘法计算错误的bug。put_along_axis reduce='mul' 结果不对, cpu正确,gpu错误 #52446
  4. 此PR是在fix behavior of put_along_axis and take_along_axis 易用性提升No.43 #59163 基础上进行的修改

@paddle-bot paddle-bot bot added the contributor External developers label Dec 4, 2023
Copy link

paddle-bot bot commented Dec 5, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@YibinLiu666
Copy link
Contributor Author

@zoooo0820 CI都过了,能否麻烦review一下,改动有点大,麻烦您了

do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是否也应该有一个返回值

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -395,6 +395,181 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
CudaAtomicAdd(imag, val.imag));
}

// For atomicMul.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些atomicMul的计算算法,能提供下参考吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的atomicMul都是参考的前面的atomicAdd以及后面的atomicMin这些,只是把加改成了乘

):
tensor = paddle.to_tensor(input)
else:
tensor = input
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这组atleast_xxx的改动是否和这个PR无关

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

应该是之前解决冲突的时候导致的,develop分支的代码是绿色部分,不知道为啥这里显示是我这个PR里面的改动,我再改改

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

现在没有这个改动了

*out, axis, index, value, include_self, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::gpu_scatter_min_kernel<T, int64_t>(
*out, axis, index, value, include_self, dev_ctx);
}
} else {
PADDLE_THROW(errors::InvalidArgument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

报错信息也更新下吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

*out, axis, index, value, include_self, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::cpu_scatter_min_kernel<T, int64_t>(
*out, axis, index, value, include_self, dev_ctx);
}
} else {
PADDLE_THROW(errors::InvalidArgument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

报错信息可以更新下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

int64_t index_idx = 0;
int* num_elements = new int[grad_size]();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处的num_elements好像没有delete。看起来该项的几处用法都是数组语义,是否可以考虑用stl替代下,更安全一些

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经改成了std::vector

@@ -268,9 +489,6 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
outer_dim_size_grad *= grad_dims[i];
}
int64_t index_idx = index.numel() - 1;
for (int i = 0; i < grad_size; i++) {
grad_data[i] = static_cast<tensor_t>(0);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

想确认下,此处的移除是因为这段赋值是非必要的,还是历史行为是错误的?

Copy link
Contributor Author

@YibinLiu666 YibinLiu666 Dec 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这一段是因为value的grad没有初始化为0,我在 #59163 这个PR里面加的,我现在把value grad初始化为0挪到了put_along_axis_grad_kernel.cc以及put_along_axis_grad_kernel.cu初始化这个grad的时候。

*x_grad,
include_self,
dev_ctx);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_type现在在外部有检查吗,这里是只会有int32/int64两种情况吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前就是只有这两种情况,我再加个检查

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经在python接口加了类型检查

self.axis_type = "int64"


class TestPutAlongAxisOpMulIncludeSelf(TestPutAlongAxisOp):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个includeself的类测试的是false的情况,命名可以加个not,更清晰一些

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

self.dtype = 'float64'
self.x_type = "float64"
self.x_shape = (10, 10, 10)
self.value_type = "float64"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GPU的场景因为添加了很多dtype的atomicMul的方法,这里能否在单测中补充下目前支持的几个数据类型在新增的reduce方案下的case,保证新增的方案是正确的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于Optest不支持int类型的输入,所以我统一用的unittest在前向计算做的测试,目前已经加了float32 bfloat16 int32 int64的测试,都没问题

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于Optest不支持int类型的输入,所以我统一用的unittest在前向计算做的测试,目前已经加了float32 bfloat16 int32 int64的测试,都没问题

因为前面代码中包含一个uint8的特殊情况,能再辛苦补充一个uint8类型的测试吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@YibinLiu666
Copy link
Contributor Author

@zoooo0820 CI现在都过了,能否辛苦您再review一下,看看还有哪里需要再改改,麻烦您了

phi::CudaAtomicMax(self_data, *src_data);
}
template <typename tensor_t,
std::enable_if_t<std::is_same<tensor_t, uint8_t>::value>* = nullptr>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请问此处是因为uint8没有对应的atomic操作,所以做的特殊处理吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,没有实现uint8的atomic操作,这里的定义我是仿照这个文件之前就定义的reduce_add写的。

self.dtype = 'float64'
self.x_type = "float64"
self.x_shape = (10, 10, 10)
self.value_type = "float64"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

由于Optest不支持int类型的输入,所以我统一用的unittest在前向计算做的测试,目前已经加了float32 bfloat16 int32 int64的测试,都没问题

因为前面代码中包含一个uint8的特殊情况,能再辛苦补充一个uint8类型的测试吗

Copy link
Contributor

@zoooo0820 zoooo0820 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@vivienfanghuagood vivienfanghuagood left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for api change

@@ -2035,7 +2035,7 @@
backward : psroi_pool_grad

- op : put_along_axis
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign")
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a parameter of broadcast in Python API, shall we also add it here as include_self or delete it from API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

broadcast is processed in the Python interface, so there is no need to pass it into the C interface again

@jeff41404
Copy link
Contributor

The docstring needs to be modified, such as missing an introduction to the parameter of values, and the example code is too simple, requiring the addition of example code for multiple usage methods. also modifying Chinese documents.

@YibinLiu666
Copy link
Contributor Author

YibinLiu666 commented Dec 12, 2023

The docstring needs to be modified, such as missing an introduction to the parameter of values, and the example code is too simple, requiring the addition of example code for multiple usage methods. also modifying Chinese documents.

Done for doc of en. Chinese doc is modified at PaddlePaddle/docs#6348

Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
一些typo小问题,新提一个 PR 修改吧 @YibinLiu666

reduce (str, optional): The reduce operation, default is 'assign', support 'add', 'assign', 'mul' and 'multiply'.
include_self (bool, optional): whether to reduce with the elements of arr. (Only support True now)
broadcast (bool, optional): whether to broadcast indices.
reduce (str, optional): The reduce operation, default is 'assign', support 'add', 'assign', 'mul', 'multiply', "mean", "amin" and "amax".
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
reduce (str, optional): The reduce operation, default is 'assign', support 'add', 'assign', 'mul', 'multiply', "mean", "amin" and "amax".
reduce (str, optional): The reduce operation, default is 'assign', support 'add', 'assign', 'mul', 'multiply', 'mean', 'amin' and 'amax'.

和前文统一吧

@luotao1 luotao1 changed the title 【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API 【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API -part Dec 13, 2023
@luotao1 luotao1 merged commit c35c63e into PaddlePaddle:develop Dec 13, 2023
29 checks passed
@YibinLiu666 YibinLiu666 deleted the put_along_axis2 branch December 13, 2023 08:52
XiaoguangHu01 pushed a commit that referenced this pull request Feb 28, 2024
* 【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API -part (#59674)

* fix bug of put_along_axis (#60551)

* Improve the performence of put_along_axis (#60618)

* fix bug of put_along_axis

* improve performence of put_along_axis

* [Bug-Fix] fix compile bug of cudaxxxAsync (#60934)

---------

Co-authored-by: YibLiu <68105073+YibinLiu666@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants